{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
   "metadata": {},
   "source": [
    "# ⚡️ Energy-Based Models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9235cbd1-f136-411c-88d9-f69f270c0b96",
   "metadata": {},
   "source": [
    "In this notebook, we'll walk through the steps required to train your own Energy-Based Model (EBM) to model the distribution of a demo dataset and generate new samples from it."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2531aef5-c81a-4b53-a344-4b979dd4eec5",
   "metadata": {},
   "source": [
    "The code is adapted from the excellent ['Deep Energy-Based Generative Models' tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial8/Deep_Energy_Models.html) created by Phillip Lippe."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84acc7be-6764-4668-b2bb-178f63deeed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import (\n",
    "    datasets,\n",
    "    layers,\n",
    "    models,\n",
    "    optimizers,\n",
    "    activations,\n",
    "    metrics,\n",
    "    callbacks,\n",
    ")\n",
    "\n",
    "from notebooks.utils import display, sample_batch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
   "metadata": {},
   "source": [
    "## 0. Parameters <a name=\"parameters\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGE_SIZE = 32  # side length (pixels) of the square model input\n",
    "CHANNELS = 1  # number of channels in the model input\n",
    "STEP_SIZE = 10  # multiplier applied to the clipped gradient in each sampler step\n",
    "STEPS = 60  # sampler steps run for every training batch\n",
    "NOISE = 0.005  # std-dev of the Gaussian noise added to real images and in the sampler\n",
    "ALPHA = 0.1  # weight of the squared-score regularisation term in the loss\n",
    "GRADIENT_CLIP = 0.03  # sampler gradients are clipped to [-GRADIENT_CLIP, GRADIENT_CLIP]\n",
    "BATCH_SIZE = 128  # batch size of the datasets and of each sampled batch\n",
    "BUFFER_SIZE = 8192  # maximum number of examples kept in the sample buffer\n",
    "LEARNING_RATE = 0.0001  # learning rate passed to the Adam optimizer\n",
    "EPOCHS = 60  # number of training epochs\n",
    "LOAD_MODEL = False  # when True, load weights from ./models/model.h5 before training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a73e5a4-1638-411c-8d3c-29f823424458",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the MNIST images; the labels are not needed and are discarded\n",
    "(x_train, _), (x_test, _) = datasets.mnist.load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20697102-8c8d-4582-88d4-f8e2af84e060",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess the data\n",
    "\n",
    "\n",
    "def preprocess(imgs):\n",
    "    \"\"\"\n",
    "    Rescale pixel values (0..255 maps to -1..1), pad each image by two\n",
    "    pixels on every side and append a trailing channel axis.\n",
    "    \"\"\"\n",
    "    scaled = (imgs.astype(\"float32\") - 127.5) / 127.5\n",
    "    # Pad only the two spatial axes; -1.0 is the rescaled value of a 0 pixel\n",
    "    border = ((0, 0), (2, 2), (2, 2))\n",
    "    padded = np.pad(scaled, border, constant_values=-1.0)\n",
    "    return padded[..., np.newaxis]\n",
    "\n",
    "\n",
    "x_train = preprocess(x_train)\n",
    "x_test = preprocess(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13668819-2e42-4661-8682-33ff2c24ae8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model.fit ignores its `shuffle` argument for tf.data.Dataset inputs, so\n",
    "# the training set is shuffled here (the order is redrawn on every epoch).\n",
    "x_train = (\n",
    "    tf.data.Dataset.from_tensor_slices(x_train)\n",
    "    .shuffle(buffer_size=len(x_train))\n",
    "    .batch(BATCH_SIZE)\n",
    ")\n",
    "# The test set keeps its original order\n",
    "x_test = tf.data.Dataset.from_tensor_slices(x_test).batch(BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7e1a420-699e-4869-8d10-3c049dbad030",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show some handwritten digits from the training set\n",
    "train_sample = sample_batch(x_train)\n",
    "display(train_sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f53945d9-b7c5-49d0-a356-bcf1d1e1798b",
   "metadata": {},
   "source": [
    "## 2. Build the EBM network <a name=\"build\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8936d951-3281-4424-9cce-59433976bf2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score network: four stride-2 swish convolutions followed by two dense\n",
    "# layers, ending in a single linear output per image.\n",
    "ebm_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))\n",
    "x = ebm_input\n",
    "for n_filters, kernel in ((16, 5), (32, 3), (64, 3), (64, 3)):\n",
    "    x = layers.Conv2D(\n",
    "        n_filters,\n",
    "        kernel_size=kernel,\n",
    "        strides=2,\n",
    "        padding=\"same\",\n",
    "        activation=activations.swish,\n",
    "    )(x)\n",
    "x = layers.Flatten()(x)\n",
    "x = layers.Dense(64, activation=activations.swish)(x)\n",
    "ebm_output = layers.Dense(1)(x)\n",
    "model = models.Model(ebm_input, ebm_output)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32221908-8819-48fa-8e57-0dc5179ca2cf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Optionally restore weights from the file the SaveModel callback writes\n",
    "if LOAD_MODEL:\n",
    "    model.load_weights(\"./models/model.h5\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f392424-45a9-49cc-8ea0-c1bec9064d74",
   "metadata": {},
   "source": [
    "## 2. Set up a Langevin sampler function <a name=\"sampler\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf10775a-0fbf-42df-aca5-be4b256a0c2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to generate samples using Langevin Dynamics\n",
    "def generate_samples(\n",
    "    model,\n",
    "    inp_imgs,\n",
    "    steps,\n",
    "    step_size,\n",
    "    noise,\n",
    "    return_img_per_step=False,\n",
    "    gradient_clip=None,\n",
    "):\n",
    "    \"\"\"\n",
    "    Refine a batch of images by noisy gradient ascent on the model output.\n",
    "\n",
    "    Every step adds Gaussian noise, clips the images to [-1, 1], then moves\n",
    "    them along the clipped gradient of the model output with respect to the\n",
    "    images and clips to [-1, 1] again.\n",
    "\n",
    "    Args:\n",
    "        model: callable mapping a batch of images to one score per image.\n",
    "        inp_imgs: starting images (tensor or NumPy array); left unmodified.\n",
    "        steps: number of update steps to run.\n",
    "        step_size: multiplier applied to the clipped gradient.\n",
    "        noise: standard deviation of the noise added at every step.\n",
    "        return_img_per_step: if True, return the images after every step.\n",
    "        gradient_clip: bound used to clip the gradient element-wise;\n",
    "            defaults to the global GRADIENT_CLIP when None.\n",
    "\n",
    "    Returns:\n",
    "        A float32 tensor with the final images, or, when\n",
    "        return_img_per_step is True, the images after each step stacked\n",
    "        along a new leading axis.\n",
    "    \"\"\"\n",
    "    if gradient_clip is None:\n",
    "        gradient_clip = GRADIENT_CLIP\n",
    "    # Work on a float32 tensor: an in-place `+=` on a NumPy input would\n",
    "    # overwrite the caller's array, and a float64 input would not match the\n",
    "    # float32 noise drawn below.\n",
    "    inp_imgs = tf.cast(inp_imgs, tf.float32)\n",
    "    imgs_per_step = []\n",
    "    for _ in range(steps):\n",
    "        inp_imgs = inp_imgs + tf.random.normal(\n",
    "            tf.shape(inp_imgs), mean=0.0, stddev=noise\n",
    "        )\n",
    "        inp_imgs = tf.clip_by_value(inp_imgs, -1.0, 1.0)\n",
    "        with tf.GradientTape() as tape:\n",
    "            # The images are plain tensors, not variables, so watch them\n",
    "            tape.watch(inp_imgs)\n",
    "            out_score = model(inp_imgs)\n",
    "        grads = tape.gradient(out_score, inp_imgs)\n",
    "        grads = tf.clip_by_value(grads, -gradient_clip, gradient_clip)\n",
    "        # Ascent step: move towards higher model output\n",
    "        inp_imgs = inp_imgs + step_size * grads\n",
    "        inp_imgs = tf.clip_by_value(inp_imgs, -1.0, 1.0)\n",
    "        if return_img_per_step:\n",
    "            imgs_per_step.append(inp_imgs)\n",
    "    if return_img_per_step:\n",
    "        return tf.stack(imgs_per_step, axis=0)\n",
    "    return inp_imgs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "180fb0a1-ed16-47c2-b326-ad66071cd6e2",
   "metadata": {},
   "source": [
    "## 3. Set up a buffer to store examples <a name=\"buffer\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52615dcd-be2b-4e05-b729-0ec45ea6ef98",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Buffer:\n",
    "    \"\"\"\n",
    "    Pool of previously generated images used to seed the sampler.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        # Start with BATCH_SIZE single-image batches of uniform noise\n",
    "        seed_shape = (1, IMAGE_SIZE, IMAGE_SIZE, CHANNELS)\n",
    "        self.examples = []\n",
    "        for _ in range(BATCH_SIZE):\n",
    "            self.examples.append(tf.random.uniform(shape=seed_shape) * 2 - 1)\n",
    "\n",
    "    def sample_new_exmps(self, steps, step_size, noise):\n",
    "        \"\"\"\n",
    "        Build a batch from fresh noise plus stored examples, run the sampler\n",
    "        on it, put the result at the front of the pool and return it.\n",
    "        \"\"\"\n",
    "        # Number of brand-new noise images ~ Binomial(BATCH_SIZE, 0.05)\n",
    "        fresh_count = np.random.binomial(BATCH_SIZE, 0.05)\n",
    "        fresh_shape = (fresh_count, IMAGE_SIZE, IMAGE_SIZE, CHANNELS)\n",
    "        fresh = tf.random.uniform(fresh_shape) * 2 - 1\n",
    "        # The rest of the batch is drawn (with replacement) from the pool\n",
    "        recycled = random.choices(self.examples, k=BATCH_SIZE - fresh_count)\n",
    "        start = tf.concat([fresh, tf.concat(recycled, axis=0)], axis=0)\n",
    "        refined = generate_samples(\n",
    "            self.model, start, steps=steps, step_size=step_size, noise=noise\n",
    "        )\n",
    "        # Newest examples go first; anything beyond BUFFER_SIZE is dropped\n",
    "        pool = tf.split(refined, BATCH_SIZE, axis=0) + self.examples\n",
    "        self.examples = pool[:BUFFER_SIZE]\n",
    "        return refined"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71a2a4a1-690e-4c94-b323-86f0e5b691d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EBM(models.Model):\n",
    "    \"\"\"\n",
    "    Keras model wrapping the global score network `model` together with a\n",
    "    sample buffer, trained with a contrastive-divergence style loss.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        self.buffer = Buffer(self.model)\n",
    "        self.alpha = ALPHA\n",
    "        self.loss_metric = metrics.Mean(name=\"loss\")\n",
    "        self.reg_loss_metric = metrics.Mean(name=\"reg\")\n",
    "        self.cdiv_loss_metric = metrics.Mean(name=\"cdiv\")\n",
    "        self.real_out_metric = metrics.Mean(name=\"real\")\n",
    "        self.fake_out_metric = metrics.Mean(name=\"fake\")\n",
    "\n",
    "    @property\n",
    "    def metrics(self):\n",
    "        # Order matters: test_step reports only the entries from index 2 on\n",
    "        return [\n",
    "            self.loss_metric,\n",
    "            self.reg_loss_metric,\n",
    "            self.cdiv_loss_metric,\n",
    "            self.real_out_metric,\n",
    "            self.fake_out_metric,\n",
    "        ]\n",
    "\n",
    "    def train_step(self, real_imgs):\n",
    "        \"\"\"\n",
    "        Run one optimisation step on a batch of real images.\n",
    "\n",
    "        Fake images come from the sample buffer. The loss is\n",
    "        mean(fake score) - mean(real score) plus alpha times the mean\n",
    "        squared scores. Returns the current value of every metric.\n",
    "        \"\"\"\n",
    "        real_imgs += tf.random.normal(\n",
    "            shape=tf.shape(real_imgs), mean=0, stddev=NOISE\n",
    "        )\n",
    "        real_imgs = tf.clip_by_value(real_imgs, -1.0, 1.0)\n",
    "        fake_imgs = self.buffer.sample_new_exmps(\n",
    "            steps=STEPS, step_size=STEP_SIZE, noise=NOISE\n",
    "        )\n",
    "        inp_imgs = tf.concat([real_imgs, fake_imgs], axis=0)\n",
    "        # The last batch of an epoch can hold fewer real images than the\n",
    "        # BATCH_SIZE fake ones, so split by the actual sizes instead of\n",
    "        # halving (halving would put fake scores into real_out).\n",
    "        split_sizes = [tf.shape(real_imgs)[0], tf.shape(fake_imgs)[0]]\n",
    "        with tf.GradientTape() as training_tape:\n",
    "            real_out, fake_out = tf.split(\n",
    "                self.model(inp_imgs), split_sizes, axis=0\n",
    "            )\n",
    "            cdiv_loss = tf.reduce_mean(fake_out, axis=0) - tf.reduce_mean(\n",
    "                real_out, axis=0\n",
    "            )\n",
    "            # Separate means so unequal batch sizes are handled; identical\n",
    "            # to the mean of the sum when both halves have the same size.\n",
    "            reg_loss = self.alpha * (\n",
    "                tf.reduce_mean(real_out**2, axis=0)\n",
    "                + tf.reduce_mean(fake_out**2, axis=0)\n",
    "            )\n",
    "            loss = cdiv_loss + reg_loss\n",
    "        grads = training_tape.gradient(loss, self.model.trainable_variables)\n",
    "        self.optimizer.apply_gradients(\n",
    "            zip(grads, self.model.trainable_variables)\n",
    "        )\n",
    "        self.loss_metric.update_state(loss)\n",
    "        self.reg_loss_metric.update_state(reg_loss)\n",
    "        self.cdiv_loss_metric.update_state(cdiv_loss)\n",
    "        self.real_out_metric.update_state(tf.reduce_mean(real_out, axis=0))\n",
    "        self.fake_out_metric.update_state(tf.reduce_mean(fake_out, axis=0))\n",
    "        return {m.name: m.result() for m in self.metrics}\n",
    "\n",
    "    def test_step(self, real_imgs):\n",
    "        \"\"\"\n",
    "        Score a batch of real images against an equally sized batch of\n",
    "        uniform-noise images and return the cdiv/real/fake metrics.\n",
    "        \"\"\"\n",
    "        batch_size = real_imgs.shape[0]\n",
    "        fake_imgs = (\n",
    "            tf.random.uniform((batch_size, IMAGE_SIZE, IMAGE_SIZE, CHANNELS))\n",
    "            * 2\n",
    "            - 1\n",
    "        )\n",
    "        inp_imgs = tf.concat([real_imgs, fake_imgs], axis=0)\n",
    "        # Both halves have batch_size rows here, so an even split is correct\n",
    "        real_out, fake_out = tf.split(self.model(inp_imgs), 2, axis=0)\n",
    "        cdiv = tf.reduce_mean(fake_out, axis=0) - tf.reduce_mean(\n",
    "            real_out, axis=0\n",
    "        )\n",
    "        self.cdiv_loss_metric.update_state(cdiv)\n",
    "        self.real_out_metric.update_state(tf.reduce_mean(real_out, axis=0))\n",
    "        self.fake_out_metric.update_state(tf.reduce_mean(fake_out, axis=0))\n",
    "        # Skip the loss and reg metrics, which are only updated in training\n",
    "        return {m.name: m.result() for m in self.metrics[2:]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6337e801-eb59-4abe-84dc-9536cf4dc257",
   "metadata": {},
   "outputs": [],
   "source": [
    "ebm = EBM()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35b14665-4359-447b-be58-3fd58ba69084",
   "metadata": {},
   "source": [
    "## 3. Train the EBM network <a name=\"train\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9ec362d-41fa-473a-ad56-ebeec6cfd3b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the model (training happens in the fit cell below). train_step\n",
    "# draws from Python-side state (NumPy / `random` sampling and a list-based\n",
    "# buffer), which is presumably why run_eagerly=True is set.\n",
    "ebm.compile(\n",
    "    optimizer=optimizers.Adam(learning_rate=LEARNING_RATE), run_eagerly=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ceca4de-f634-40ff-beb8-09ba42fd0f75",
   "metadata": {},
   "outputs": [],
   "source": [
    "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
    "\n",
    "\n",
    "class ImageGenerator(callbacks.Callback):\n",
    "    \"\"\"\n",
    "    At the end of every epoch, generate `num_img` images from noise and\n",
    "    pick `num_img` examples from the sample buffer, then pass both sets to\n",
    "    `display` with a path under ./output.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_img):\n",
    "        # Initialise the base Callback so its standard attributes exist\n",
    "        super().__init__()\n",
    "        self.num_img = num_img\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        # display() is given paths under ./output (presumably to save the\n",
    "        # figures there), so make sure the directory exists first\n",
    "        os.makedirs(\"./output\", exist_ok=True)\n",
    "        # float32 to match the dtype of the noise generate_samples adds\n",
    "        start_imgs = (\n",
    "            np.random.uniform(\n",
    "                size=(self.num_img, IMAGE_SIZE, IMAGE_SIZE, CHANNELS)\n",
    "            )\n",
    "            * 2\n",
    "            - 1\n",
    "        ).astype(\"float32\")\n",
    "        generated_images = generate_samples(\n",
    "            ebm.model,\n",
    "            start_imgs,\n",
    "            steps=1000,\n",
    "            step_size=STEP_SIZE,\n",
    "            noise=NOISE,\n",
    "            return_img_per_step=False,\n",
    "        )\n",
    "        generated_images = generated_images.numpy()\n",
    "        display(\n",
    "            generated_images,\n",
    "            save_to=\"./output/generated_img_%03d.png\" % (epoch),\n",
    "        )\n",
    "\n",
    "        # Show as many buffer examples as generated images\n",
    "        example_images = tf.concat(\n",
    "            random.choices(ebm.buffer.examples, k=self.num_img), axis=0\n",
    "        )\n",
    "        example_images = example_images.numpy()\n",
    "        display(\n",
    "            example_images, save_to=\"./output/example_img_%03d.png\" % (epoch)\n",
    "        )\n",
    "\n",
    "\n",
    "image_generator_callback = ImageGenerator(num_img=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "627c1387-f29a-4cce-85a8-0903c1890e23",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SaveModel(callbacks.Callback):\n",
    "    \"\"\"\n",
    "    Write the weights of the global score network `model` to\n",
    "    ./models/model.h5 at the end of every epoch.\n",
    "    \"\"\"\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        # Create the target directory first so that a missing ./models\n",
    "        # cannot make the save fail after an epoch of training\n",
    "        os.makedirs(\"./models\", exist_ok=True)\n",
    "        # Save the inner network (not the EBM wrapper): this is the object\n",
    "        # the LOAD_MODEL cell restores the weights into\n",
    "        model.save_weights(\"./models/model.h5\")\n",
    "\n",
    "\n",
    "save_model_callback = SaveModel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd6a5a71-eb55-4ec0-9c8c-cb11a382ff90",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Train for the number of epochs set in the parameters cell\n",
    "ebm.fit(\n",
    "    x_train,\n",
    "    shuffle=True,\n",
    "    epochs=EPOCHS,\n",
    "    validation_data=x_test,\n",
    "    callbacks=[\n",
    "        save_model_callback,\n",
    "        tensorboard_callback,\n",
    "        image_generator_callback,\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb1f295f-ade0-4040-a6a5-a7b428b08ebc",
   "metadata": {},
   "source": [
    "## 4. Generate images <a name=\"generate\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8db3cfe3-339e-463d-8af5-fbd403385fca",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_imgs = (\n",
    "    np.random.uniform(size=(10, IMAGE_SIZE, IMAGE_SIZE, CHANNELS)) * 2 - 1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80087297-3f47-4e0c-ac89-8758d4386d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(start_imgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaf4b749-5f6e-4a12-863f-b0bbcd23549c",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "gen_img = generate_samples(\n",
    "    ebm.model,\n",
    "    start_imgs,\n",
    "    steps=1000,\n",
    "    step_size=STEP_SIZE,\n",
    "    noise=NOISE,\n",
    "    return_img_per_step=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eac707f6-0597-499c-9a52-7cade6724795",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "display(gen_img[-1].numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8476aaa1-e0e7-44dc-a1fd-cc30344b8dcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Follow one image of the batch (index 6) at selected sampler steps\n",
    "snapshot_steps = (0, 1, 3, 5, 10, 30, 50, 100, 300, 999)\n",
    "imgs = [gen_img[step].numpy()[6] for step in snapshot_steps]\n",
    "\n",
    "display(np.array(imgs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e05d5a0-e124-40d1-80ed-552260c6e350",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
